/*=========================================================================
 *
 *  Copyright NumFOCUS
 *
 *  Licensed under the Apache License, Version 2.0 (the "License");
 *  you may not use this file except in compliance with the License.
 *  You may obtain a copy of the License at
 *
 *         https://www.apache.org/licenses/LICENSE-2.0.txt
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 *
 *=========================================================================*/

#include "itkPCAShapeSignedDistanceFunction.h"

#include "itkImageRegionIterator.h"
#include "itkEuler2DTransform.h"

#include <gtest/gtest.h>

#include <cmath>    // For cos and sin.
#include <iostream> // For cout.
#include <random>   // For mt19937.

/**
 * This module tests the functionality of the PCAShapeSignedDistanceFunction
 * class.
 *
 * The mean image, principal component images, standard deviations, and the
 * shape and pose parameters are generated by a default-seeded std::mt19937.
 * The signed distance is evaluated at every image point (after mapping the
 * point through the inverse of the Euler2D pose transform) and compared to
 * the expected value
 *   mean(x) + sum_i pc_i(x) * stddev_i * weight_i.
 * The test fails if an evaluated result is not within a certain tolerance
 * of its expected result. Finally, it verifies that Initialize() throws an
 * itk::ExceptionObject for several invalid configurations.
 */
TEST(PCAShapeSignedDistanceFunction, Test)
{
  using CoordRep = double;
  constexpr unsigned int Dimension{ 2 };
  constexpr unsigned int ImageWidth{ 3 };
  constexpr unsigned int ImageHeight{ 2 };
  constexpr unsigned int NumberOfPCs{ 3 };


  // define a pca shape function
  using ShapeFunction = itk::PCAShapeSignedDistanceFunction<CoordRep, Dimension>;
  const auto shape = ShapeFunction::New();
  //  shape->DebugOn();
  shape->SetNumberOfPrincipalComponents(NumberOfPCs);


  // set up the transform
  using transformType = itk::Euler2DTransform<double>;
  const auto transform = transformType::New();
  shape->SetTransform(transform);


  // prepare for image creation
  using ImageType = ShapeFunction::ImageType;
  using ImageIterator = itk::ImageRegionIterator<ImageType>;

  constexpr ImageType::SizeType imageSize{ ImageWidth, ImageHeight };

  const ImageType::RegionType region{ imageSize };


  // set up the random number generator
  std::mt19937                     randomNumberEngine{};
  std::normal_distribution<double> randomNumberDistribution(0.0, 1.0);

  // Allocates an image over `region` and fills its pixels, in iteration order,
  // with draws from randomNumberDistribution.
  const auto createRandomImage = [&]() {
    auto image = ImageType::New();
    image->SetRegions(region);
    image->Allocate();

    ImageIterator imageIt(image, image->GetBufferedRegion());
    for (imageIt.GoToBegin(); !imageIt.IsAtEnd(); ++imageIt)
    {
      const ImageType::PixelType randomPixel = randomNumberDistribution(randomNumberEngine);
      imageIt.Set(randomPixel);
    }
    return image;
  };

  // set up the mean image
  const auto meanImage = createRandomImage();
  shape->SetMeanImage(meanImage);


  // set up the NumberOfPCs principal component images
  ShapeFunction::ImagePointerVector pcImages(NumberOfPCs);
  for (auto & pcImage : pcImages)
  {
    pcImage = createRandomImage();
  }
  shape->SetPrincipalComponentImages(pcImages);


  // set up the standard deviation for each principal component image
  ShapeFunction::ParametersType pcStandardDeviations(NumberOfPCs);
  for (unsigned int i = 0; i < NumberOfPCs; ++i)
  {
    pcStandardDeviations[i] = randomNumberDistribution(randomNumberEngine);
  }
  shape->SetPrincipalComponentStandardDeviations(pcStandardDeviations);


  // set up the parameters: shape weights first, followed by the pose (transform) parameters
  const unsigned int            numberOfShapeParameters = shape->GetNumberOfShapeParameters();
  const unsigned int            numberOfPoseParameters = shape->GetNumberOfPoseParameters();
  const unsigned int            numberOfParameters = numberOfShapeParameters + numberOfPoseParameters;
  ShapeFunction::ParametersType parameters(numberOfParameters);
  for (unsigned int i = 0; i < numberOfParameters; ++i)
  {
    parameters[i] = randomNumberDistribution(randomNumberEngine);
  }
  shape->SetParameters(parameters);


  // we must initialize the function before use
  shape->Initialize();

  // check pca shape calculation
  // The pose parameters are [angle, translationX, translationY], stored after the shape weights.
  constexpr unsigned int numberOfRotationParameters = Dimension * (Dimension - 1) / 2;
  const unsigned int     startIndexOfTranslationParameters = numberOfShapeParameters + numberOfRotationParameters;

  // The pose is the same for every point, so compute the inverse rotation and translation once.
  const double angle = parameters[numberOfShapeParameters];
  const double inverseCos = std::cos(-angle);
  const double inverseSin = std::sin(-angle);
  const double translationX = parameters[startIndexOfTranslationParameters];
  const double translationY = parameters[startIndexOfTranslationParameters + 1];

  ShapeFunction::PointType                     point;
  ShapeFunction::TransformType::InputPointType p;
  ShapeFunction::TransformType::InputPointType q;

  ImageIterator meanImageIt(meanImage, meanImage->GetBufferedRegion());
  for (meanImageIt.GoToBegin(); !meanImageIt.IsAtEnd(); ++meanImageIt)
  {
    // from index to physical point
    const ImageType::IndexType index = meanImageIt.GetIndex();
    meanImage->TransformIndexToPhysicalPoint(index, point);

    // Inverse Euler2DTransform: undo the translation first, then the rotation.
    // This assumes the transform center is left at its default (the origin).
    p[0] = point[0] - translationX;
    p[1] = point[1] - translationY;

    q[0] = p[0] * inverseCos - p[1] * inverseSin;
    q[1] = p[0] * inverseSin + p[1] * inverseCos;

    // evaluate shape function
    const ShapeFunction::OutputType output = shape->Evaluate(q);

    // calculate expected function value
    ShapeFunction::OutputType expected = meanImage->GetPixel(index);
    for (unsigned int i = 0; i < NumberOfPCs; ++i)
    {
      expected += pcImages[i]->GetPixel(index) * pcStandardDeviations[i] * parameters[i];
    }

    // check result
    std::cout << "f(" << point << ") = " << output << std::endl;

    EXPECT_NEAR(output, expected, 1e-9);
  }

  // Evaluate at a point outside the image domain
  std::cout << "Evaluate at point outside image domain" << std::endl;
  q.Fill(5.0);
  const ShapeFunction::OutputType outsideOutput = shape->Evaluate(q);
  std::cout << "f(" << q << ") = " << outsideOutput << std::endl;

  // Exercise other methods for test coverage
  shape->Print(std::cout);

  std::cout << "NumberOfPrincipalComponents: " << shape->GetNumberOfPrincipalComponents() << std::endl;
  std::cout << "MeanImage: " << shape->GetMeanImage() << std::endl;
  std::cout << "PrincipalComponentStandardDeviations: " << shape->GetPrincipalComponentStandardDeviations()
            << std::endl;
  std::cout << "Transform: " << shape->GetTransform() << std::endl;
  std::cout << "Parameters: " << shape->GetParameters() << std::endl;

  // Exercise error testing.
  // Applies a bad component and expects Initialize() to throw an itk::ExceptionObject.
  // It then restores the good component so the next check starts from a valid configuration.
  // This local lambda replaces a function-scope macro that leaked past the end of this test.
  const auto expectInitializationError = [&shape](const auto & applyBadComponent, const auto & restoreGoodComponent) {
    applyBadComponent();
    bool caughtExpectedException = false;
    try
    {
      shape->Initialize();
    }
    catch (const itk::ExceptionObject & err)
    {
      std::cout << "Caught expected ExceptionObject" << std::endl;
      std::cout << err << std::endl;
      caughtExpectedException = true;
    }
    restoreGoodComponent();

    EXPECT_TRUE(caughtExpectedException);
  };

  const auto restoreGoodPCImages = [&] { shape->SetPrincipalComponentImages(pcImages); };

  // nullptr MeanImage
  expectInitializationError([&] { shape->SetMeanImage(nullptr); }, [&] { shape->SetMeanImage(meanImage); });

  // Wrong number of PC images
  ShapeFunction::ImagePointerVector badPCImages;
  badPCImages.resize(1);
  badPCImages[0] = nullptr;

  expectInitializationError([&] { shape->SetPrincipalComponentImages(badPCImages); }, restoreGoodPCImages);

  // A nullptr PC image
  badPCImages = pcImages;
  badPCImages[1] = nullptr;

  expectInitializationError([&] { shape->SetPrincipalComponentImages(badPCImages); }, restoreGoodPCImages);

  // A PC image of the wrong size
  const auto                  badSize = ImageType::SizeType::Filled(1);
  const ImageType::RegionType badRegion(badSize);
  badPCImages[1] = ImageType::New();
  badPCImages[1]->SetRegions(badRegion);
  badPCImages[1]->AllocateInitialized();

  expectInitializationError([&] { shape->SetPrincipalComponentImages(badPCImages); }, restoreGoodPCImages);
}
